# Base Docker Image of verl, with CUDA/cuDNN/Torch/FlashAttn/Apex/DeepEP, without other frameworks
# NOTE(review): TransformerEngine is uninstalled below and never reinstalled in this file, so it is
# presumably added by a downstream image -- confirm.
# Target: verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.1-fa2.8.0-fi0.2.6
# NOTE(review): the target tag advertises fa2.8.0 / fi0.2.6, but this file installs
# flash-attn 2.7.4.post1 and does not install flashinfer -- confirm which is intended.
# Start from the NVIDIA official image (ubuntu-22.04 + cuda-12.6 + python-3.10)
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-08.html
FROM nvcr.io/nvidia/pytorch:24.08-py3

# Environment defaults baked into the image (visible both to later build steps and at runtime).
#   MAX_JOBS                      - presumably caps parallel compile jobs of the source builds below
#   VLLM_WORKER_MULTIPROC_METHOD  - presumably selects the start method of vLLM workers
#   DEBIAN_FRONTEND               - keeps apt/dpkg from prompting during installs
#   NODE_OPTIONS                  - explicitly set to the empty string
#   PIP_ROOT_USER_ACTION          - silences pip's "running as root" warning
#   HF_HUB_ENABLE_HF_TRANSFER     - presumably makes huggingface_hub use the hf-transfer package installed below
ENV MAX_JOBS=16 \
    VLLM_WORKER_MULTIPROC_METHOD=spawn \
    DEBIAN_FRONTEND=noninteractive \
    NODE_OPTIONS="" \
    PIP_ROOT_USER_ACTION=ignore \
    HF_HUB_ENABLE_HF_TRANSFER="1"

# Define installation arguments
# Both default to the Tsinghua (TUNA) mirrors; override them with `--build-arg` when building elsewhere.
#   APT_SOURCE - base URL written into /etc/apt/sources.list by the next step
#   PIP_INDEX  - index URL configured globally for pip during the build (unset again at the end of this file)
ARG APT_SOURCE=https://mirrors.tuna.tsinghua.edu.cn/ubuntu/
ARG PIP_INDEX=https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple

# Point apt at ${APT_SOURCE}: back up the stock sources.list, then regenerate it with one
# "deb" entry per Ubuntu 22.04 (jammy) suite, each covering all four components.
RUN cp /etc/apt/sources.list /etc/apt/sources.list.bak && \
    for suite in jammy jammy-updates jammy-backports jammy-security; do \
        echo "deb ${APT_SOURCE} ${suite} main restricted universe multiverse"; \
    done > /etc/apt/sources.list

# Install OS packages in a single layer:
#   - systemd (provides systemctl); --force-confdef lets dpkg pick the default action for
#     conffile prompts for this package instead of stopping to ask
#   - tini, aria2 (used for the large downloads below), htop, FreeImage and zlib
# The apt package cache and the downloaded package lists are removed in the same layer so
# they do not end up in the image; every later apt step runs its own `apt-get update`.
RUN apt-get update && \
    apt-get install -y -o Dpkg::Options::="--force-confdef" systemd && \
    apt-get install -y \
        aria2 \
        htop \
        libfreeimage-dev \
        libfreeimage3 \
        tini \
        zlib1g && \
    apt-get clean && \
    rm -rf /var/lib/apt/lists/*

# Route pip through ${PIP_INDEX} for the rest of the build (both the primary and the extra
# index point at the same mirror), then upgrade pip itself.
# NOTE(review): later steps pass --resume-retries, which presumably requires the upgraded pip -- confirm.
RUN for key in index-url extra-index-url; do \
        pip config set "global.${key}" "${PIP_INDEX}" || exit 1; \
    done && \
    python -m pip install --upgrade pip

# Remove the NVIDIA-built PyTorch stack and related packages shipped in the base image,
# so the packages installed below start from a clean slate.
RUN pip uninstall -y \
    torch \
    torchvision \
    torchaudio \
    pytorch-quantization \
    pytorch-triton \
    torch-tensorrt \
    xgboost \
    transformer_engine \
    flash_attn \
    apex \
    megatron-core \
    grpcio

# Install the pinned PyTorch release and its matching torchvision / torchaudio builds.
# --resume-retries lets pip resume interrupted downloads of the large wheels.
RUN pip install \
    --resume-retries 999 \
    --no-cache-dir \
    torch==2.7.1 \
    torchvision==0.22.1 \
    torchaudio==2.7.1

# Install flash-attn-2.7.4.post1, although built with torch2.6, it is compatible with torch2.7
# https://github.com/Dao-AILab/flash-attention/issues/1644#issuecomment-2899396361
# The wheel variant (cxx11abiTRUE / cxx11abiFALSE) is chosen to match the C++ ABI of the torch
# that was just installed. The wheel name is defined once and reused for the download URL,
# and the downloaded file is deleted in the same layer so it is not kept in the image.
RUN ABI_FLAG=$(python -c "import torch; print('TRUE' if torch._C._GLIBCXX_USE_CXX11_ABI else 'FALSE')") && \
    FILE="flash_attn-2.7.4.post1+cu12torch2.6cxx11abi${ABI_FLAG}-cp310-cp310-linux_x86_64.whl" && \
    URL="https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/${FILE}" && \
    wget -nv "${URL}" && \
    pip install --no-cache-dir "${FILE}" && \
    rm -f "${FILE}"

# Fix packages: drop both NVML Python bindings (pynvml, nvidia-ml-py), then install
# nvidia-ml-py again together with minimum versions of a few other dependencies
# (grpcio was removed together with the NVIDIA PyTorch stack above).
RUN pip uninstall -y pynvml nvidia-ml-py && \
    pip install --no-cache-dir --upgrade \
        "nvidia-ml-py>=12.560.30" \
        "fastapi[standard]>=0.115.0" \
        "optree>=0.13.0" \
        "pydantic>=2.9" \
        "grpcio>=1.62.1"

# Install cuDNN 9.8.0 for CUDA 12 from NVIDIA's local-repo package:
#   1. download the local-repo .deb and register it with dpkg (the .deb is deleted right after)
#   2. copy the repo's signing key into /usr/share/keyrings so apt trusts the repo
#   3. install cudnn-cuda-12 from that repo
# The apt package cache and package lists are removed in the same layer so they do not end up
# in the image; every later apt step runs its own `apt-get update`.
RUN CUDNN_DEB="cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb" && \
    aria2c --max-tries=9999 "https://developer.download.nvidia.com/compute/cudnn/9.8.0/local_installers/${CUDNN_DEB}" && \
    dpkg -i "${CUDNN_DEB}" && \
    rm -f "${CUDNN_DEB}" && \
    cp /var/cudnn-local-repo-ubuntu2204-9.8.0/cudnn-*-keyring.gpg /usr/share/keyrings/ && \
    apt-get update && \
    apt-get -y install cudnn-cuda-12 && \
    apt-get clean && \
    rm -rf /var/lib/apt/lists/*

# Install Apex from the tip of the NVIDIA/apex repository, requesting its C++ and CUDA
# extensions via the --cpp_ext / --cuda_ext build options. Build isolation is disabled,
# so the build runs against the packages already installed in this image.
RUN pip install -v \
    --disable-pip-version-check \
    --no-cache-dir \
    --no-build-isolation \
    --config-settings "--build-option=--cpp_ext" \
    --config-settings "--build-option=--cuda_ext" \
    --resume-retries 999 \
    git+https://github.com/NVIDIA/apex.git

# Profiling tools: install Nsight Systems 2025.3.1 and make the CUDA toolkit's nsys / nsys-ui
# entries point at it.
# Download, install and cleanup all happen in ONE layer: deleting the .deb in a later RUN would
# not shrink the image, because the file would still be stored in the layer that downloaded it.
# The apt package cache and package lists are removed in the same layer as well.
RUN NSYS_DEB="nsight-systems-2025.3.1_2025.3.1.90-1_amd64.deb" && \
    NSYS_BIN="/opt/nvidia/nsight-systems/2025.3.1/target-linux-x64" && \
    aria2c --always-resume=true --max-tries=99999 "https://developer.nvidia.com/downloads/assets/tools/secure/nsight-systems/2025_3/${NSYS_DEB}" && \
    apt-get update && \
    apt-get install -y libxcb-cursor0 && \
    apt-get install -y "./${NSYS_DEB}" && \
    rm -f "${NSYS_DEB}" && \
    for tool in nsys nsys-ui; do \
        rm -rf "/usr/local/cuda/bin/${tool}" && \
        ln -s "${NSYS_BIN}/${tool}" "/usr/local/cuda/bin/${tool}" || exit 1; \
    done && \
    apt-get clean && \
    rm -rf /var/lib/apt/lists/*

# Install the Python dependencies: data/model libraries, runtime utilities, and dev tooling.
# Every requirement containing shell-special characters ([], <, >) is quoted; an unquoted
# `ray[default]` is a glob pattern that the shell could expand against files in the working directory.
RUN pip install --resume-retries 999 --no-cache-dir \
    "tensordict==0.6.2" torchdata "transformers[hf_xet]>=4.52.3" accelerate datasets peft hf-transfer \
    "numpy<2.0.0" "pyarrow>=19.0.1" pandas cuda-bindings \
    "ray[default]" codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \
    pytest py-spy pyext pre-commit ruff

# Install DeepEP
## the dependency of IBGDA
## Expose the versioned libmlx5.so.1 under the unversioned name libmlx5.so.
## NOTE(review): presumably needed so the NVSHMEM build with NVSHMEM_IBGDA_SUPPORT=1 can link
## against libmlx5 -- confirm. `ln -s` fails if libmlx5.so already exists.
RUN ln -s /usr/lib/x86_64-linux-gnu/libmlx5.so.1 /usr/lib/x86_64-linux-gnu/libmlx5.so

## Fetch the sources into the current working directory: gdrcopy at tag v2.3.1, and DeepEP
## pinned to commit a84a248 so the build does not follow the moving default branch.
RUN git clone -b v2.3.1 https://github.com/NVIDIA/gdrcopy.git && \
    git clone https://github.com/deepseek-ai/DeepEP.git && \
    git -C DeepEP checkout a84a248

# Prepare nvshmem: download the NVSHMEM 3.2.5-1 sources, unpack them as ./deepep-nvshmem and
# apply the patch shipped in DeepEP's third-party directory.
# The tarball is deleted right after extraction, in the same layer, so it is not kept in the image.
RUN wget https://developer.nvidia.com/downloads/assets/secure/nvshmem/nvshmem_src_3.2.5-1.txz && \
    tar -xvf nvshmem_src_3.2.5-1.txz && \
    rm -f nvshmem_src_3.2.5-1.txz && \
    mv nvshmem_src deepep-nvshmem && \
    cd deepep-nvshmem && git apply ../DeepEP/third-party/nvshmem.patch

# Environment for the NVSHMEM / DeepEP builds below.
#   CUDA_HOME       - CUDA toolkit location
#   CPATH           - MPI headers prepended; builds fail when the MPI variables are not set
#   LD_LIBRARY_PATH - /usr/local/x86_64-linux-gnu and the MPI libraries prepended, in that order
#   GDRCOPY_HOME    - location of the gdrcopy checkout cloned above
# NOTE(review): /usr/local/x86_64-linux-gnu looks like it may have been meant to be
# /usr/lib/x86_64-linux-gnu -- confirm before changing.
ENV CUDA_HOME=/usr/local/cuda \
    CPATH=/usr/local/mpi/include:${CPATH} \
    LD_LIBRARY_PATH=/usr/local/x86_64-linux-gnu:/usr/local/mpi/lib:${LD_LIBRARY_PATH} \
    GDRCOPY_HOME=/workspace/gdrcopy

## Build deepep-nvshmem
## Configure the patched NVSHMEM sources with CMake (Ninja generator), then build and install
## them into /workspace/deepep-nvshmem/install.
## The NVSHMEM_* assignments are a shell environment prefix, so they apply only to the
## configure invocation (the first cmake), not to `cmake --build`:
##   - enabled (=1):  IBGDA support, GDRCopy
##   - disabled (=0): SHMEM, UCX, NCCL, MPI and PMIx support, device-polling timeout
## NOTE(review): the absolute install prefix assumes the working directory is /workspace -- confirm.
RUN cd deepep-nvshmem && \
    NVSHMEM_SHMEM_SUPPORT=0 \
    NVSHMEM_UCX_SUPPORT=0 \
    NVSHMEM_USE_NCCL=0 \
    NVSHMEM_MPI_SUPPORT=0 \
    NVSHMEM_IBGDA_SUPPORT=1 \
    NVSHMEM_PMIX_SUPPORT=0 \
    NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \
    NVSHMEM_USE_GDRCOPY=1 \
    cmake -G Ninja -S . -B build/ -DCMAKE_INSTALL_PREFIX=/workspace/deepep-nvshmem/install && cmake --build build/ --target install

# Publish the NVSHMEM install tree: NVSHMEM_DIR names its root, and its lib/ and bin/
# directories are prepended to the library and executable search paths.
# NVSHMEM_DIR is set in its own instruction because variables assigned within one ENV
# instruction are not visible to the other assignments of that same instruction.
ENV NVSHMEM_DIR=/workspace/deepep-nvshmem/install
ENV LD_LIBRARY_PATH=${NVSHMEM_DIR}/lib:${LD_LIBRARY_PATH} \
    PATH=${NVSHMEM_DIR}/bin:${PATH}

## Build deepep
## Build and install DeepEP from the checkout pinned to commit a84a248 above.
## NOTE(review): presumably setup.py locates the NVSHMEM build through the NVSHMEM_DIR
## environment variable set just before this step -- confirm.
RUN cd DeepEP && \
    python setup.py install

# Reset pip config: drop the build-time mirror settings again, so the finished image does
# not carry the index URLs configured from PIP_INDEX.
RUN for key in index-url extra-index-url; do \
        pip config unset "global.${key}" || exit 1; \
    done

